[megatron] fix: enable_routing_replay fails with MLATransformerConfig…#5884
Conversation
… (mbridge) When using R3 router replay with DeepSeek models via vanilla mbridge, _build_tf_config() passes enable_routing_replay=True through bridge.set_extra_args(**override_transformer_config). This calls MLATransformerConfig.__init__() which doesn't accept this kwarg: TypeError: MLATransformerConfig.__init__() got an unexpected keyword argument 'enable_routing_replay' MLATransformerConfig is a dataclass that generates its own __init__ and doesn't call the patched TransformerConfig.__init__ that the router replay patch modifies to accept enable_routing_replay. Fix: remove enable_routing_replay from override_transformer_config and set it directly as an attribute on tf_config/provider after construction. Tested on 10B DeepSeek MoE (megatron + sglang, 32x H100, R3 mode).
There was a problem hiding this comment.
Code Review
This pull request refactors how the 'enable_routing_replay' flag is handled in 'transformer_impl.py' by moving its assignment from the 'override_transformer_config' dictionary to direct attribute setting on the provider and configuration objects. This change avoids potential 'TypeError' issues with dataclass-based configurations that do not support arbitrary keyword arguments. I have suggested adding a defensive 'pop' operation to ensure the flag is removed from the override dictionary if present, preventing unexpected initialization errors.
I am having trouble creating individual review comments. Click here to see my feedback.
verl/workers/engine/megatron/transformer_impl.py (166-167)
To ensure that enable_routing_replay does not cause a TypeError when initializing dataclass-based configurations (like MLATransformerConfig), it should be explicitly removed from the override_transformer_config dictionary. While the logic that was adding it has been removed, it could still be present if provided via self.engine_config.override_transformer_config. The attribute is now correctly set directly on the config or provider objects later in the function.
override_transformer_config.pop("enable_routing_replay", None)verl-project#5884) ### What does this PR do? Fixes R3 router replay crash when used with DeepSeek models via vanilla mbridge (`MLATransformerConfig`). `_build_tf_config()` passes `enable_routing_replay=True` through `bridge.set_extra_args(**override_transformer_config)`, which calls `MLATransformerConfig.__init__()`. But `MLATransformerConfig` is a Python dataclass — it generates its own `__init__` and doesn't call the patched `TransformerConfig.__init__` that accepts `enable_routing_replay`. Result: ``` TypeError: MLATransformerConfig.__init__() got an unexpected keyword argument 'enable_routing_replay' ``` This affects any mbridge model config that is a dataclass subclass of `TransformerConfig` (DeepSeek, potentially others). **Fix:** Remove `enable_routing_replay` from `override_transformer_config` dict. Set it directly as an attribute on `tf_config` / `provider` after construction. Related: verl-project#4567 (similar fix for `Qwen3VLTransformerConfig`, still open). This PR is more generic — works for any dataclass config subclass. ### Checklist Before Starting - [x] Search for similar PRs: [enable_routing_replay MLATransformerConfig](https://github.com/verl-project/verl/pulls?q=is%3Apr+enable_routing_replay+MLATransformerConfig) — no existing PR for this specific bug - [x] Format the PR title as `[{modules}] {type}: {description}` ### Test Tested on 10B DeepSeek MoE model (megatron engine + sglang rollout, R3 mode, 32x H100 GPUs, PP=2 TP=2 EP=2): - Without fix: `TypeError: MLATransformerConfig.__init__() got an unexpected keyword argument 'enable_routing_replay'` - With fix: R3 router replay works correctly - `rollout_corr/log_ppl_diff = 0.0003` (sglang and megatron log-probs match) - `rollout_corr/kl = 0.0003` - `actor/grad_norm = 0.44` (stable, normal range) - Training runs for 10+ steps without issues ### API and Usage Example No API changes. Existing R3 router replay config works as before: ```bash actor_rollout_ref.actor.megatron.router_replay.mode=R3 actor_rollout_ref.rollout.enable_rollout_routing_replay=True ``` ### Design & Code Changes Single file change in `verl/workers/engine/megatron/transformer_impl.py`: 1. **Remove** `enable_routing_replay` from `override_transformer_config` dict (was passed to `bridge.set_extra_args()` which forwards to dataclass `__init__`) 2. **Add** `tf_config.enable_routing_replay = True` after config creation (vanilla mbridge path) 3. **Add** `provider.enable_routing_replay = True` after overrides (non-vanilla mbridge path) ### Checklist Before Submitting - [x] Read the Contribute Guide. - [ ] Apply pre-commit checks. - [ ] Add / Update the documentation. — N/A, no doc changes needed. - [ ] Add unit or end-to-end test(s). — Not feasible: requires mbridge + DeepSeek model + multi-GPU setup. The existing `tests/special_e2e/run_ppo_trainer_megatron.sh` with `ROUTING_REPLAY_MODE=R3` covers this path when run on MoE models. - [ ] Once your PR is ready for CI, send a message in the `ci-request` channel.
What does this PR do?
Fixes R3 router replay crash when used with DeepSeek models via vanilla mbridge (
MLATransformerConfig)._build_tf_config()passesenable_routing_replay=Truethroughbridge.set_extra_args(**override_transformer_config), which callsMLATransformerConfig.__init__(). ButMLATransformerConfigis a Python dataclass — it generates its own__init__and doesn't call the patchedTransformerConfig.__init__that acceptsenable_routing_replay. Result:This affects any mbridge model config that is a dataclass subclass of
TransformerConfig(DeepSeek, potentially others).Fix: Remove
enable_routing_replayfromoverride_transformer_configdict. Set it directly as an attribute ontf_config/providerafter construction.Related: #4567 (similar fix for
Qwen3VLTransformerConfig, still open). This PR is more generic — works for any dataclass config subclass.Checklist Before Starting
[{modules}] {type}: {description}Test
Tested on 10B DeepSeek MoE model (megatron engine + sglang rollout, R3 mode, 32x H100 GPUs, PP=2 TP=2 EP=2):
TypeError: MLATransformerConfig.__init__() got an unexpected keyword argument 'enable_routing_replay'rollout_corr/log_ppl_diff = 0.0003(sglang and megatron log-probs match)rollout_corr/kl = 0.0003actor/grad_norm = 0.44(stable, normal range)API and Usage Example
No API changes. Existing R3 router replay config works as before:
Design & Code Changes
Single file change in
verl/workers/engine/megatron/transformer_impl.py:enable_routing_replayfromoverride_transformer_configdict (was passed tobridge.set_extra_args()which forwards to dataclass__init__)tf_config.enable_routing_replay = Trueafter config creation (vanilla mbridge path)provider.enable_routing_replay = Trueafter overrides (non-vanilla mbridge path)Checklist Before Submitting
tests/special_e2e/run_ppo_trainer_megatron.shwithROUTING_REPLAY_MODE=R3covers this path when run on MoE models.ci-requestchannel.